# -*- coding: utf-8 -*-
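"""
Load a trained PhyCRNet run logged under exp/exp9/log_conc/, report its
parameter counts, and compute SHAP attributions on the training inputs.

Created on Fri Apr 15 14:22:23 2022
"""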

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np
from torch.autograd import Variable
import time
import pandas as pd

import networkx as nx
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
from datetime import datetime
from torch.nn.utils import clip_grad_norm_
import copy
import random
import pickle
import os

from models import PhyCRNet,MaskedMSELoss,MaskedL1Loss,weightMSELoss
from data_utils import PrepareConcDataset
from utils import plot_conc,read_logs


#%%load args and model

# Directory holding the logged runs of this experiment.
log_path='exp/exp9/log_conc/'

# read_logs returns a dict with the entries 'log_dirs', 'model_dirs',
# 'args_dirs', 'events_dirs' and 'orders'; the *_dirs entries are indexed
# by run below (presumably parallel per-run lists -- confirm in utils).
dirs=read_logs(log_path)

# Index of the logged run to analyse. To process every run instead, loop
# over range(len(dirs['log_dirs'])).
i=4
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load args files written by this project's own training runs.
with open(dirs['args_dirs'][i],'rb') as f:
    args=pickle.load(f)

args['cuda'] = torch.cuda.is_available()

# Always define `device`: use the GPU when present, otherwise fall back to the
# CPU so the .to(device=...) calls below work on CPU-only machines too.
device = torch.device('cuda' if args['cuda'] else 'cpu')

(train_input, train_label1, train_label1_coe, train_label2, train_label2_mask,
 val_input, val_label1, val_label1_coe, val_label2, val_label2_mask,
 mmin_list, dev_list, conc_name, edge_list) = PrepareConcDataset(
    args['time_window'], args['res'], args['train_ratio'], args['buffer'])

# Move the train/validation inputs and labels onto the selected device.
train_input=train_input.to(device=device)
train_label1=train_label1.to(device=device)
train_label1_coe=train_label1_coe.to(device=device)
train_label2=train_label2.to(device=device)
train_label2_mask=train_label2_mask.to(device=device)
val_input=val_input.to(device=device)
val_label1=val_label1.to(device=device)
val_label1_coe=val_label1_coe.to(device=device)
val_label2=val_label2.to(device=device)
val_label2_mask=val_label2_mask.to(device=device)


#reconfig model
# Rebuild the network from the hyper-parameters saved with this run, then
# restore its trained weights.
model = PhyCRNet(args['input_dim'],
                 args['hidden_dim'],
                 args['output_dim'],
                 args['input_kernel_size'],
                 args['input_stride'],
                 args['padding_mode'],
                 args['num_layers'],
                 args['warmup_updates'],
                 args['tot_updates'],
                 args['peak_lr'],
                 args['end_lr'],
                 args['weight_decay'],
                 args['optim'])
# Use the GPU only when one is available (args['cuda'] was set from
# torch.cuda.is_available()), so this cell also runs on CPU-only hosts.
device = torch.device('cuda' if args['cuda'] else 'cpu')
model=model.to(device=device)

model_path=dirs['model_dirs'][i]
# map_location remaps tensors saved on a GPU onto `device`, so a GPU-trained
# checkpoint can be loaded without CUDA.
model.load_state_dict(torch.load(model_path, map_location=device))


#%%
# Collect the model's parameters whose names contain 'weight' or 'bias'.
# A name containing both substrings is counted in both lists.
weight_list = [(name, param.data) for name, param in model.named_parameters()
               if 'weight' in name]
bias_list = [(name, param.data) for name, param in model.named_parameters()
             if 'bias' in name]

# numel() is the number of scalar entries in each tensor; sum() of an empty
# sequence is 0, matching the original running totals.
total_weight = sum(tensor.numel() for _, tensor in weight_list)
total_bias = sum(tensor.numel() for _, tensor in bias_list)
print(f'the total number of weight is {total_weight}')
print(f'the total number of bias is {total_bias}')


#%%
# SHAP attributions for the loaded model (shap is imported here because it is
# only needed by this cell).
import shap

# The first two training inputs serve as the explainer's background
# (reference) data; the remaining training inputs are the ones explained.
background = train_input[:2,...]
to_explain = train_input[2:,...]

e = shap.GradientExplainer(model, background)
shap_values = e.shap_values(to_explain)
















